LSTM
一种常用的 循环神经网络(RNN) 模块,用于处理具有时序依赖特征的数据(如语音、文本、时间序列等)。每个时间步的公式化描述如下。
\(x_t\) : 当前时间步输入向量
\(h_{t-1}\) : 上一时间步的隐藏状态
\(c_{t-1}\) : 上一时间步的细胞状态
\(i_t, f_t, g_t, o_t\) : 四个门(输入门、遗忘门、候选门、输出门)
\(W_*\) : 对应的权重矩阵
\(b_*\) : 偏置项
\(\sigma(\cdot)\) : Sigmoid 函数
\(\odot\) : 元素乘
- 输入:
input - 输入序列数据,形状为 \((seq_len, batch, input_size)\),即每个时间步的输入特征。
weight_i - 输入到各门 \((input、forget、cell、output)\) 的权重矩阵,大小为 4 * hidden_size * input_size。
weight_h - 上一隐藏状态到各门的权重矩阵,大小为 \(4 * hidden_size * hidden_size\)
input_bias - 输入部分的偏置项,对应 4 个门的偏置。
state_bias - 隐藏状态部分的偏置项(也是 \(4 * hidden_size\)),与 input_bias 一起求和形成总偏置。
hidden_state - 当前批次初始隐藏状态输入( \(h₀\) ),执行后更新为最后时刻的隐藏状态输出( \(hₜ\))
cell_state - 当前批次初始细胞状态输入( \(c₀\)),执行后更新为最后时刻的细胞状态输出( \(cₜ\))。
buffer - 临时工作区指针数组(中间计算缓存,如门值、激活结果、临时矩阵等,用于优化性能)。
LstmParameter - LSTM 配置参数结构体,包含输入大小、隐藏层维度、序列长度、是否双向等信息。
core_mask - 核掩码(仅适用于共享存储版本)。
LstmParameter定义:
1typedef struct LstmParameter {
2int input_size_;//每个时间步输入向量的维度(输入特征数)。
3int hidden_size_;//LSTM 隐藏状态的维度(每个门的内部计算大小)。
4int project_size_;//投影层输出维度(用于 LSTMP,有则在输出前线性压缩隐藏状态)。
5int output_size_;//实际输出维度,等于 hidden_size_ 或 project_size_(取决于是否使用投影层)。
6int seq_len_;//输入序列的时间步数(序列长度)。
7int batch_;//批次大小(一次处理的样本数量)。
8// other parameter
9int output_step_;//指定输出第几个时间步的结果(通常为最后一步或每步)。
10bool bidirectional_;//是否为双向 LSTM(true 表示前向和后向各一层)。
11float zoneout_cell_;//单元状态的 Zoneout 比例(防止过拟合的正则化参数)。
12float zoneout_hidden_;//隐藏状态的 Zoneout 比例(防止过拟合)。
13int input_row_align_;//输入张量的行对齐参数(用于 DMA 或 SIMD 加速的内存对齐)。
14int input_col_align_;//输入张量的列对齐参数。
15int state_row_align_;//状态张量(hidden/cell)的行对齐参数。
16int state_col_align_;//状态张量的列对齐参数。
17int proj_col_align_;//投影层矩阵的列对齐参数。
18bool has_bias_;//是否包含偏置项(true 表示使用 bias)。
19} LstmParameter;
- 输出:
output - 计算结果地址,存放 LSTM 每个时间步输出结果的缓冲区,维度通常为 \((seq\_len, batch, output\_size)\)
- 支持平台:
FT78NEMT7004
备注
FT78NE 支持fp32
MT7004 支持fp32
共享存储版本:
-
void fp_Lstm_s(float *output, const float *input, const float *weight_i, const float *weight_h, const float *input_bias, const float *state_bias, float *hidden_state, float *cell_state, float *buffer[9], const LstmParameter *lstm_param, int core_mask)
C调用示例:
1//FT78NE示例 2#include <stdio.h> 3#include <lstm.h> 4 5int main(int argc, char* argv[]) { 6 LstmParameter *lstm_param = (LstmParameter *)0x90000000; 7 lstm_param->seq_len_ = 20; 8 lstm_param->batch_ = 1; 9 lstm_param->input_size_ = 2000; 10 lstm_param->hidden_size_ = 3; 11 lstm_param->bidirectional_ = false; 12 float * input = (float *)0xA0000000; //input在DDR空间 13 float * weight_i = (float *)0xA1000000; 14 float * weight_h = (float *)0xA3000000; 15 float *input_bias_ =(float *) 0xB0900000; 16 float * state_bias_ =(float *) 0xB0B00000; 17 float * output_s = (float *)0xC0000000; 18 float *hidden_state_s = (float *)0xC0100000; 19 float *cell_state_s = (float *)0xC0200000; 20 float *buffer[9]; 21 float * packed_input_ = (float *)0xB0000000; 22 buffer[0] = packed_input_; 23 float * gate = (float *)0xB0100000; 24 buffer[1] = gate; 25 float * packed_state = (float *)0xB0200000; 26 buffer[2] = packed_state; 27 float * state_gate = (float *)0xB0300000; 28 buffer[3] = state_gate; 29 float * cell_buffer = (float *)0xB0400000; 30 buffer[4] = cell_buffer; 31 float * hidden_buffer = (float *)0xB0500000; 32 buffer[5] = hidden_buffer; 33 float * packed_output = (float *)0xB0600000; 34 buffer[6] = packed_output; 35 float * left_matrix = (float *)0xB0700000; 36 buffer[7] = left_matrix; 37 float * packed_ptr = (float *)0xB0800000; 38 buffer[8] = packed_ptr; 39 int core_mask = 0xff; 40 fp_Lstm_s(output_s, input, weight_i, weight_h, input_bias_, 41 state_bias, hidden_state_s, cell_state_s, buffer, 42 lstm_param, core_mask); 43 return 0; 44}
私有存储版本:
-
void fp_Lstm_p(float *output, const float *input, const float *weight_i, const float *weight_h, const float *input_bias, const float *state_bias, float *hidden_state, float *cell_state, float *buffer[9], const LstmParameter *lstm_param)
C调用示例:
1//FT78NE示例 2#include <stdio.h> 3#include <lstm.h> 4int main(int argc, char* argv[]) { 5 LstmParameter *lstm_param = (LstmParameter *)0x10000000; 6 lstm_param->seq_len_ = 4; 7 lstm_param->batch_ = 1; 8 lstm_param->input_size_ = 2; 9 lstm_param->hidden_size_ = 3; 10 lstm_param->bidirectional_ = false; 11 float * input = (float *)0x10000200; //input在DDR空间 12 float * weight_i = (float *)0x10000400; 13 float * weight_h = (float *)0x10000600; 14 float *input_bias_ =(float *) 0x10000800; 15 float * state_bias_ =(float *) 0x10000A00; 16 float * output_s = (float *)0x10000C00; 17 float *hidden_state_s = (float *)0x10000E00; 18 float *cell_state_s = (float *)0x10001000; 19 float *buffer[9]; 20 float * packed_input_ = (float *)0x10001200; 21 buffer[0] = packed_input_; 22 float * gate = (float *)0x10001400; 23 buffer[1] = gate; 24 float * packed_state = (float *)0x10001600; 25 buffer[2] = packed_state; 26 float * state_gate = (float *)0x10001800; 27 buffer[3] = state_gate; 28 float * cell_buffer = (float *)0x10001A00; 29 buffer[4] = cell_buffer; 30 float * hidden_buffer = (float *)0x10001C00; 31 buffer[5] = hidden_buffer; 32 float * packed_output = (float *)0x10001F00; 33 buffer[6] = packed_output; 34 float * left_matrix = (float *)0x10002000; 35 buffer[7] = left_matrix; 36 float * packed_ptr = (float *)0x10002200; 37 buffer[8] = packed_ptr; 38 fp_Lstm_p(output_s, input, weight_i, weight_h, input_bias_, 39 state_bias, hidden_state_s, cell_state_s, buffer, 40 lstm_param); 41 return 0; 42}